Skip to content

feat: add protocols for sample proposal and accept/reject#1722

Merged
janfb merged 2 commits intomainfrom
add-protocols-in-accep-reject
Feb 6, 2026
Merged

feat: add protocols for sample proposal and accept/reject#1722
janfb merged 2 commits intomainfrom
add-protocols-in-accep-reject

Conversation

@janfb
Copy link
Contributor

@janfb janfb commented Jan 9, 2026

Closes #1395

Context

In #1370, the proposal argument in accept_reject_sample was changed from torch.Distribution to Callable to support passing sample functions directly. This made the type too permissive - there's no enforcement that the callable has the expected signature.

Motivation

Replace generic Callable types with Python Protocols to enable static type checking of the expected signatures. This catches mismatches at type-check time rather than runtime.

Changes

  • Added SampleProposal protocol: (sample_shape: torch.Size, **kwargs) -> Tensor
  • Added AcceptRejectFn protocol: (theta: Tensor) -> Tensor
  • Updated accept_reject_sample to use these protocols
  • Fixed call site to use torch.Size() instead of plain tuple (removed # type: ignore)
  • Added **kwargs to _sample_via_diffusion in vector_field_posterior.py for protocol compliance

@codecov
Copy link

codecov bot commented Jan 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 84.71%. Comparing base (e71d0b4) to head (01f47ff).
⚠️ Report is 11 commits behind head on main.
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1722      +/-   ##
==========================================
- Coverage   88.51%   84.71%   -3.81%     
==========================================
  Files         137      137              
  Lines       11527    11518       -9     
==========================================
- Hits        10203     9757     -446     
- Misses       1324     1761     +437     
Flag Coverage Δ
fast 84.71% <100.00%> (?)
full ?

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/posteriors/vector_field_posterior.py 69.12% <ø> (-8.06%) ⬇️
sbi/samplers/rejection/rejection.py 92.43% <100.00%> (-2.81%) ⬇️
sbi/sbi_types.py 100.00% <100.00%> (ø)
sbi/utils/restriction_estimator.py 76.31% <100.00%> (-8.65%) ⬇️

... and 30 files with indirect coverage changes

@janfb janfb force-pushed the add-protocols-in-accep-reject branch from 28c58d3 to 0d50d00 Compare January 13, 2026 15:46
@dgedon dgedon self-requested a review January 24, 2026 06:11
Copy link
Collaborator

@dgedon dgedon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully get why this change is necessary and whether it is implemented in a way that has lots of code overhead for limited gain. Maybe you can clarify.

@janfb
Copy link
Contributor Author

janfb commented Jan 28, 2026

Thanks for the review @dgedon

To clarify: the PR description refers to an earlier change (PR #1370) that converted proposal from torch.Distribution to Callable. This PR is about making that Callable type more precise with Protocols.

On whether Protocols are necessary vs Callable[[torch.Size], Tensor]: the key issue is that SampleProposal needs to accept **kwargs. Looking at actual usage:

  candidates = proposal(                                                                                                                                             
      torch.Size((sampling_batch_size,)),                                                                                                                            
      **proposal_sampling_kwargs,  # passes {"condition": x}, {"predictor": ..., "corrector": ...}, etc.                                                             
  )

The proposals (posterior_estimator.sample, _sample_via_diffusion, etc.) all require additional kwargs like condition. You can't express (sample_shape: torch.Size, **kwargs) -> Tensor with a simple Callable type hint — that's the main reason for the Protocol.

For AcceptRejectFn, you're right that Callable[[Tensor], Tensor] would technically suffice since it's only ever called with theta. I kept it as a Protocol for consistency and more explicit documentation, but happy to simplify if you prefer.

@janfb janfb requested review from dgedon and removed request for gmoss13 and manuelgloeckler February 2, 2026 09:11
@janfb janfb self-assigned this Feb 2, 2026
Copy link
Collaborator

@dgedon dgedon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the answers, it clarifies it well. Good to merge

@janfb janfb merged commit fa4d7a9 into main Feb 6, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

constrain the type of proposal in accep_reject_sample

2 participants